package dep

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/token"
	"io"
	"os"
	"path/filepath"

	"golang.org/x/tools/imports"
)

// Autodig carries the configuration for one generation run (see GenDigFile).
//
// Fields:
//   - importHandler: set from NewImportHandler by NewAutodig; genDecls calls
//     its GetAllImports method.
//   - scanDirs: the directories handed to getAllFiles. handleParam rewrites
//     them as absolute paths without a trailing "/".
//   - outputDir: where the generated code goes. Despite the name, after
//     handleParam it holds a file path: it is made absolute and, unless it
//     already ends in ".go", "autodig.go" is appended to it.
//   - cmdTag: forwarded unchanged to BuildDecls; its meaning is not visible
//     in this file.
type Autodig struct {
	importHandler ImportHandler
	scanDirs      []string
	outputDir     string
	cmdTag        string
}

// NewAutodig builds an Autodig from the given scan directories, output
// location and command tag, pairing it with a fresh import handler obtained
// from NewImportHandler. The arguments are stored as-is; path normalisation
// happens later, when GenDigFile runs.
func NewAutodig(scanDirs []string, outputDir string, cmdTag string) *Autodig {
	dig := new(Autodig)
	dig.importHandler = NewImportHandler()
	dig.scanDirs = scanDirs
	dig.outputDir = outputDir
	dig.cmdTag = cmdTag
	return dig
}

// GenDigFile runs the whole generation pipeline: it normalises the configured
// paths, removes any previous output file, builds the declarations and writes
// them to the output file.
//
// Errors from the generation, create and write steps are printed to stdout and
// returned; a parameter-normalisation error is returned without printing. A
// failure to close the output file is returned as well (when no earlier error
// occurred), because a failed close can mean the generated code never reached
// disk.
func (a *Autodig) GenDigFile() (err error) {
	if err = a.handleParam(); err != nil {
		return err
	}

	// Removing the previous output is best-effort. A missing file is the
	// normal situation on a first run, so only unexpected failures are
	// reported.
	if rmErr := os.Remove(a.outputDir); rmErr != nil && !os.IsNotExist(rmErr) {
		fmt.Println("delete autodig file warn", rmErr.Error())
	}

	decls, outputPkgName, err := a.genDecls()
	if err != nil {
		fmt.Println(err)
		return err
	}

	outputfile, err := os.Create(a.outputDir)
	if err != nil {
		fmt.Println(err)
		return err
	}
	defer func() {
		// Surface the close error unless an earlier error is already being
		// returned.
		if closeErr := outputfile.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	if err = a.write(outputfile, outputPkgName, decls); err != nil {
		fmt.Println(err)
		return err
	}
	return nil
}

// handleParam normalises the configured paths in place.
//
// It first makes every path absolute via absParam, then:
//   - strips a single trailing "/" from each scan directory (a lone "/" is
//     left untouched so the root directory does not become an empty string);
//   - turns outputDir into a file path: when it does not already end in
//     ".go", "autodig.go" is appended, inserting a "/" separator if needed.
//
// It returns the error from absParam, if any.
func (a *Autodig) handleParam() error {
	if err := a.absParam(); err != nil {
		return err
	}

	for i, scanDir := range a.scanDirs {
		if n := len(scanDir); n > 1 && scanDir[n-1] == '/' {
			a.scanDirs[i] = scanDir[:n-1]
		}
	}

	// filepath.Ext is used instead of slicing the last three bytes, which
	// panicked for paths shorter than ".go" (for example "/" or "/a").
	if filepath.Ext(a.outputDir) != ".go" {
		if n := len(a.outputDir); n > 0 && a.outputDir[n-1] != '/' {
			a.outputDir += "/"
		}
		a.outputDir += "autodig.go"
	}
	return nil
}

// absParam replaces every scan directory and the output location with the
// value returned by absPath. It stops at the first failure and returns that
// error; the entry that failed has already been overwritten with absPath's
// result by then.
func (a *Autodig) absParam() error {
	for idx := range a.scanDirs {
		resolved, err := absPath(a.scanDirs[idx])
		a.scanDirs[idx] = resolved
		if err != nil {
			return err
		}
	}

	resolved, err := absPath(a.outputDir)
	a.outputDir = resolved
	return err
}

// genDecls produces the declarations for the generated file together with the
// package name taken from the import context. Any failing step aborts the run
// and its error is returned with the name of that step prefixed.
func (a *Autodig) genDecls() ([]ast.Decl, string, error) {
	sources, scanErr := a.getAllFiles(a.scanDirs)
	if scanErr != nil {
		return nil, "", fmt.Errorf("getAllFiles err: %v ", scanErr)
	}

	// First pass: collect all imports and their aliases.
	imps, importErr := a.importHandler.GetAllImports(sources, a.outputDir)
	if importErr != nil {
		return nil, "", fmt.Errorf("getAllImports err: %v ", importErr)
	}

	// Second pass: build the functions.
	builder := NewFileBuilder(imps)
	built, buildErr := builder.BuildDecls(sources, imps, a.cmdTag)
	if buildErr != nil {
		return nil, "", fmt.Errorf("buildDecls err: %v ", buildErr)
	}

	return built, imps.outputPkgName, nil
}

// write assembles the generated source and sends it to wr.
//
// The source starts with the "Code generated" header and the package clause
// for pkgName, followed by each declaration in funcs rendered by astToGo. The
// assembled text is passed through imports.Process before being written, so
// the bytes that reach wr are that function's output. The first error from
// rendering, processing or writing is returned.
func (a *Autodig) write(wr io.Writer, pkgName string, funcs []ast.Decl) error {
	var src bytes.Buffer
	fmt.Fprintf(&src, `// Code generated by autodig. DO NOT EDIT.
package %s
`, pkgName)

	for i := range funcs {
		if err := a.astToGo(&src, funcs[i]); err != nil {
			return err
		}
	}

	processed, err := imports.Process("", src.Bytes(), nil)
	if err != nil {
		return err
	}

	_, err = wr.Write(processed)
	return err
}

// astToGo appends a newline to dst and then the Go source for node, as
// printed by format.Node with a fresh, empty file set. It returns the first
// error reported by either step.
func (a *Autodig) astToGo(dst *bytes.Buffer, node interface{}) error {
	// The leading newline separates this node from whatever dst already holds.
	if err := dst.WriteByte('\n'); err != nil {
		return err
	}
	return format.Node(dst, token.NewFileSet(), node)
}

// containsString reports whether target is equal to any element of slice.
// A nil or empty slice contains nothing.
func containsString(slice []string, target string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == target {
			return true
		}
	}
	return false
}

// absPath returns path unchanged when it is already absolute; otherwise it
// returns the result of filepath.Abs. If filepath.Abs fails, the returned
// string is empty and the error names the offending path.
func absPath(path string) (string, error) {
	if filepath.IsAbs(path) {
		return path, nil
	}

	resolved, err := filepath.Abs(path)
	if err != nil {
		return "", fmt.Errorf("abs path err, path:%s, err: %v", path, err)
	}
	return resolved, nil
}
